{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# softmax回归的从零开始实现\n",
    ":label:`sec_softmax_scratch`\n",
    "\n",
    "就像我们从零开始实现线性回归一样，我们认为softmax回归也是重要的基础，因此你应该知道实现softmax的细节。我们使用刚刚在 :numref:`sec_fashion_mnist` 中引入的Fashion-MNIST数据集，并设置数据迭代器的批量大小为256。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 1,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils.ipynb\n",
    "%load ../utils/Training.java\n",
    "\n",
    "import ai.djl.basicdataset.cv.classification.*;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 4,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "boolean randomShuffle = true;\n",
    "\n",
    "// get training and validation dataset\n",
    "FashionMnist trainingSet = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .setSampling(batchSize, randomShuffle)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "FashionMnist validationSet = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TEST)\n",
    "        .setSampling(batchSize, false)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 初始化模型参数\n",
    "\n",
    "和之前线性回归的例子一样，这里的每个样本都将用固定长度的向量表示。原始数据集中的每个样本都是 $28 \\times 28$ 的图像。在本节中，我们将展平每个图像，把它们看作长度为784的向量。在后面的章节中，将讨论能够利用图像空间结构的更为复杂的策略，但现在我们暂时只把每个像素位置看作一个特征。\n",
    "\n",
    "回想一下，在softmax回归中，我们的输出与类别一样多。因为我们的数据集有10个类别，所以网络输出维度为 10。因此，权重将构成一个 $784 \\times 10$ 的矩阵，偏置将构成一个 $1 \\times 10$ 的行向量。与线性回归一样，我们将使用正态分布初始化我们的权重 `W`，偏置初始化为0。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 6,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "int numInputs = 784;\n",
    "int numOutputs = 10;\n",
    "\n",
    "NDManager manager = NDManager.newBaseManager();\n",
    "NDArray W = manager.randomNormal(0, 0.01f, new Shape(numInputs, numOutputs), DataType.FLOAT32);\n",
    "NDArray b = manager.zeros(new Shape(numOutputs), DataType.FLOAT32);\n",
    "NDList params = new NDList(W, b);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义softmax操作\n",
    "\n",
    "在实现softmax回归模型之前，让我们简要地回顾一下`sum()`运算符如何沿着`NDArray`中的特定维度工作，如 :numref:`subseq_lin-alg-reduction` 和 :numref:`subseq_lin-alg-non-reduction` 所述。给定一个矩阵`X`，我们可以对所有元素求和（默认情况下），也可以只求同一个轴上的元素，即同一列`new int[]{0}`或同一行`new int[]{0, 1}`。如果 `X` 是一个形状为 `(2, 3)` 的`NDArray`，我们对列进行求和，则结果将是一个具有形状 `(3,)` 的向量。当调用`sum`运算符时，我们可以指定保持在原始 `NDArray` 的轴数，而不折叠求和的维度。这将产生一个具有形状 `(1, 3)` 的二维 `NDArray` 。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 11,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "NDArray X = manager.create(new int[][]{{1, 2, 3}, {4, 5, 6}});\n",
    "System.out.println(X.sum(new int[]{0}, true));\n",
    "System.out.println(X.sum(new int[]{1}, true));\n",
    "System.out.println(X.sum(new int[]{0, 1}, true));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们现在已经准备好实现softmax操作了。回想一下，softmax 由三个步骤组成：\n",
    "\n",
    "1. 对每个项求幂（使用`exp()`）。\n",
    "2. 对每一行求和（小批量中每个样本是一行），得到每个样本的归一化常数。\n",
    "3. 将每一行除以其归一化常数，确保结果的和为1。\n",
    "\n",
    "在查看代码之前，让我们回顾一下这个表达式：\n",
    "\n",
    "$$\n",
    "\\mathrm{softmax}(\\mathbf{X})_{ij} = \\frac{\\exp(\\mathbf{X}_{ij})}{\\sum_k \\exp(\\mathbf{X}_{ik})}.\n",
    "$$\n",
    "\n",
    "分母或归一化常数，有时也称为*配分函数*（其对数称为对数-配分函数）。该名称的起源来自 [统计物理学](https://en.wikipedia.org/wiki/Partition_function_(statistical_mechanics))中一个模拟粒子群分布的方程。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 13,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public NDArray softmax(NDArray X) {\n",
    "    NDArray Xexp = X.exp();\n",
    "    NDArray partition = Xexp.sum(new int[]{1}, true);\n",
    "    return Xexp.div(partition); // 这里应用了广播机制\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "正如你所看到的，对于任何随机输入，我们将每个元素变成一个非负数。此外，依据概率原理，每行总和为1。注意，虽然这在数学上看起来是正确的，但我们在代码实现中有点草率。矩阵中的非常大或非常小的元素可能造成数值上溢或下溢，但我们没有采取措施来防止这点。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 16,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "NDArray X = manager.randomNormal(new Shape(2, 5));\n",
    "NDArray Xprob = softmax(X);\n",
    "System.out.println(Xprob);\n",
    "System.out.println(Xprob.sum(new int[]{1}));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型\n",
    "\n",
    "现在我们已经定义了softmax操作，我们可以实现softmax回归模型。下面的代码定义了输入如何通过网络映射到输出。注意，在将数据传递到我们的模型之前，我们使用 `reshape()` 函数将每张原始图像展平为向量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 19,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// We need to wrap `net()` in a class so that we can reference the method\n",
    "// and pass it as a parameter to a function or save it in a variable\n",
    "public class Net {\n",
    "    public static NDArray net(NDArray X) {\n",
    "        NDArray currentW = params.get(0);\n",
    "        NDArray currentB = params.get(1);\n",
    "        return softmax(X.reshape(new Shape(-1, numInputs)).dot(currentW).add(currentB));\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义损失函数\n",
    "\n",
    "接下来，我们需要实现 :numref:`sec_softmax` 中引入的交叉熵损失函数。这可能是深度学习中最常见的损失函数，因为目前分类问题的数量远远超过回归问题。\n",
    "\n",
    "回顾一下，交叉熵采用真实标签的预测概率的负对数似然。我们不需要使用`Java`的`for`循环迭代预测（这往往是低效的）。\n",
    "\n",
    "我们可以使用 `NDIndex` 表达式选择 `NDArray` 索引的元素，下面，我们创建一个数据`yHat`，其中包含2个样本在3个类别的预测概率，我们知道在第一个样本中，第一类是正确的预测，而在第二个样本中，第三类是正确的预测。我们可以使用 \":, {}\" 表达式选择正确的预测。 NDArray: `{0, 2}` 作为 `yHat` 中概率的索引，表示选择第一个样本中第 0 列和第二个样本中 2 列。\n",
    "\n",
    "注意：创建 `NDIndex` 时使用的 `NDArray` 的数据类型必须是 `int` 或 `long`。你需要使用 `toType()` 函数将非整形 `NDArray` 转成 `DataType.INT32` 或 `DataType.INT64`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 21,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "NDArray yHat = manager.create(new float[][]{{0.1f, 0.3f, 0.6f}, {0.3f, 0.2f, 0.5f}});\n",
    "yHat.get(new NDIndex(\":, {}\", manager.create(new int[]{0, 2})));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在我们只需一行代码就可以实现交叉熵损失函数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 24,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Cross Entropy only cares about the target class's probability\n",
    "// Get the column index for each row\n",
    "public class LossFunction {\n",
    "    public static NDArray crossEntropy(NDArray yHat, NDArray y) {\n",
    "        // Here, y is not guranteed to be of datatype int or long\n",
    "        // and in our case we know its a float32.\n",
    "        // We must first convert it to int or long(here we choose int)\n",
    "        // before we can use it with NDIndex to \"pick\" indices. \n",
    "        // It also takes in a boolean for returning a copy of the existing NDArray\n",
    "        // but we don't want that so we pass in `false`.\n",
    "        return yHat.get(new NDIndex(\":, {}\", y.toType(DataType.INT32, false))).log().neg();\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 分类准确率\n",
    "\n",
    "给定预测概率分布 `yHat`，当我们必须输出硬预测（hard prediction）时，我们通常选择预测概率最高的类。许多应用都要求我们做出选择。如Gmail必须将电子邮件分为“Primary（主要）”、“Social（社交）”、“Updates（更新）”或“Forums（论坛）”。它可能在内部估计概率，但最终它必须在类中选择一个。\n",
    "\n",
    "当预测与标签分类 `y` 一致时，它们是正确的。分类准确率即正确预测数量与总预测数量之比。虽然直接优化准确率可能很困难（因为准确率的计算不可导），但准确率通常是我们最关心的性能衡量标准，我们在训练分类器时几乎总是会报告它。\n",
    "\n",
    "为了计算准确率，我们执行以下操作。首先，如果 `yHat` 是矩阵，那么假定第二个维度存储每个类的预测分数。我们使用 `yHat.argMax()` 获得每行中最大元素的索引来获得预测类别。然后我们将预测类别与真实 `y` 元素进行比较。由于 `eq()` 函数要求数据类型也一致，因此我们将 `yHat` 的数据类型转换为与 `y` 相同的数据类型。结果是一个包含 0（错）和 1（对）的 `NDArray`。进行求和会得到正确预测的数量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 27,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Saved in the utils for later use\n",
    "public float accuracy(NDArray yHat, NDArray y) {\n",
    "    // Check size of 1st dimension greater than 1\n",
    "    // to see if we have multiple samples\n",
    "    if (yHat.getShape().size(1) > 1) {\n",
    "        // Argmax gets index of maximum args for given axis 1\n",
    "        // Convert yHat to same dataType as y (int32)\n",
    "        // Sum up number of true entries\n",
    "        return yHat.argMax(1).toType(DataType.INT32, false).eq(y.toType(DataType.INT32, false))\n",
    "            .sum().toType(DataType.FLOAT32, false).getFloat();\n",
    "    }\n",
    "    return yHat.toType(DataType.INT32, false).eq(y.toType(DataType.INT32, false))\n",
    "        .sum().toType(DataType.FLOAT32, false).getFloat();\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将继续使用之前定义的变量 `yHat` 和 `y` 分别作为预测的概率分布和标签。我们可以看到，第一个样本的预测类别是2（该行的最大元素为0.6，索引为2），这与实际标签0不一致。第二个样本的预测类别是2（该行的最大元素为0.5，索引为 2），这与实际标签2一致。因此，这两个样本的分类准确率率为0.5。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 29,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "NDArray y = manager.create(new int[]{0,2});\n",
    "accuracy(yHat, y) / y.size();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "同样，对于任意数据迭代器 `dataIterator` 可访问的数据集，我们可以评估在任意模型 `net` 的准确率。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.function.UnaryOperator; \n",
    "import java.util.function.BinaryOperator; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 31,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Saved in the utils for future use\n",
    "public float evaluateAccuracy(UnaryOperator<NDArray> net, Iterable<Batch> dataIterator) {\n",
    "    Accumulator metric = new Accumulator(2);  // numCorrectedExamples, numExamples\n",
    "    for (Batch batch : dataIterator) {\n",
    "        NDArray X = batch.getData().head();\n",
    "        NDArray y = batch.getLabels().head();\n",
    "        metric.add(new float[]{accuracy(net.apply(X), y), (float)y.size()});\n",
    "        batch.close();\n",
    "    }\n",
    "    return metric.get(0) / metric.get(1);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里 `Accumulator` 是一个实用程序类，用于对多个变量进行累加。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 34,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Saved in utils for future use        \n",
    "/* Sum a list of numbers over time */\n",
    "public class Accumulator {\n",
    "    float[] data;\n",
    "    \n",
    "    public Accumulator(int n) {\n",
    "        data = new float[n];\n",
    "    }\n",
    "    \n",
    "    /* Adds a set of numbers to the array */\n",
    "    public void add(float[] args) {\n",
    "        for (int i = 0; i < args.length; i++) {\n",
    "            data[i] += args[i];\n",
    "        }\n",
    "    }\n",
    "\n",
    "    /* Resets the array */\n",
    "    public void reset() {\n",
    "        Arrays.fill(data, 0f);\n",
    "    }\n",
    "\n",
    "    /* Returns the data point at the given index */\n",
    "    public float get(int index) {\n",
    "        return data[index];\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于我们使用随机权重初始化 `net` 模型，因此该模型的准确率应接近于随机猜测。例如在有10个类别情况下的准确率为0.1。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 36,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "evaluateAccuracy(Net::net, validationSet.getData(manager));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练\n",
    "\n",
    "如果你看过 :numref:`sec_linear_scratch` 中的线性回归实现，softmax回归的训练过程代码应该看起来非常熟悉。在这里，我们重构训练过程的实现以使其可重复使用。首先，我们定义一个函数来训练一个迭代周期。请注意，`updater()` 是更新模型参数的常用函数，它接受批量大小作为参数。它可以是封装的`Traning.sgd()`函数，也可以是框架的内置优化函数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 38,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "@FunctionalInterface\n",
    "public static interface ParamConsumer {\n",
    "     void accept(NDList params, float lr, int batchSize);\n",
    "}\n",
    "\n",
    "public float[] trainEpochCh3(UnaryOperator<NDArray> net, Iterable<Batch> trainIter, BinaryOperator<NDArray> loss, ParamConsumer updater) {\n",
    "    Accumulator metric = new Accumulator(3); // trainLossSum, trainAccSum, numExamples\n",
    "    \n",
    "    // Attach Gradients\n",
    "    for (NDArray param : params) {\n",
    "        param.setRequiresGradient(true);\n",
    "    }\n",
    "    \n",
    "    for (Batch batch : trainIter) {\n",
    "        NDArray X = batch.getData().head();\n",
    "        NDArray y = batch.getLabels().head();\n",
    "        X = X.reshape(new Shape(-1, numInputs));\n",
    "            \n",
    "        try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "            // Minibatch loss in X and y\n",
    "            NDArray yHat = net.apply(X);\n",
    "            NDArray l = loss.apply(yHat, y);\n",
    "            gc.backward(l);  // Compute gradient on l with respect to w and b\n",
    "            metric.add(new float[]{l.sum().toType(DataType.FLOAT32, false).getFloat(), \n",
    "                                   accuracy(yHat, y), \n",
    "                                   (float)y.size()});\n",
    "            gc.close();\n",
    "        }\n",
    "        updater.accept(params, lr, batch.getSize());  // Update parameters using their gradient\n",
    "        \n",
    "        batch.close();\n",
    "    }\n",
    "    // Return trainLoss, trainAccuracy\n",
    "    return new float[]{metric.get(0) / metric.get(2), metric.get(1) / metric.get(2)}; \n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在展示训练函数的实现之前，我们定义一个在动画中绘制数据的实用程序类。它能够简化本书其余部分的代码。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 42,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "import tech.tablesaw.api.Row;\n",
    "import tech.tablesaw.columns.Column;\n",
    "\n",
    "// Saved in utils\n",
    "/* Animates a graph with real-time data. */\n",
    "class Animator {\n",
    "    private String id; // Id reference of graph(for updating graph)\n",
    "    private Table data; // Data Points\n",
    "    \n",
    "    public Animator() {\n",
    "        id = \"\";\n",
    "        \n",
    "        // Incrementally plot data\n",
    "        data = Table.create(\"Data\")\n",
    "        .addColumns(\n",
    "            FloatColumn.create(\"epoch\", new float[]{}),\n",
    "            FloatColumn.create(\"value\", new float[]{}),\n",
    "            StringColumn.create(\"metric\", new String[]{})\n",
    "        );\n",
    "    }\n",
    "\n",
    "    // Add a single metric to the table\n",
    "    public void add(float epoch, float value, String metric) {\n",
    "        Row newRow = data.appendRow();\n",
    "        newRow.setFloat(\"epoch\", epoch);\n",
    "        newRow.setFloat(\"value\", value);\n",
    "        newRow.setString(\"metric\", metric);\n",
    "    }\n",
    "    \n",
    "    // Add accuracy, train accuracy, and train loss metrics for a given epoch\n",
    "    // Then plot it on the graph\n",
    "    public void add(float epoch, float accuracy, float trainAcc, float trainLoss) {\n",
    "        add(epoch, trainLoss, \"train loss\");\n",
    "        add(epoch, trainAcc, \"train accuracy\");\n",
    "        add(epoch, accuracy, \"test accuracy\");\n",
    "        show();\n",
    "    }\n",
    "    \n",
    "    // Display the graph\n",
    "    public void show() {\n",
    "        if (id.equals(\"\")) {\n",
    "            id = display(LinePlot.create(\"\", data, \"epoch\", \"value\", \"metric\"));\n",
    "            return;\n",
    "        }\n",
    "        update();\n",
    "    }\n",
    "    \n",
    "    // Update the graph\n",
    "    public void update() {\n",
    "        updateDisplay(id, LinePlot.create(\"\", data, \"epoch\", \"value\", \"metric\"));\n",
    "    }\n",
    "    \n",
    "    // Returns the column at the given index\n",
    "    // if it is a float column\n",
    "    // Otherwise returns null\n",
    "    public float[] getY(Table t, int index) {\n",
    "        Column c = t.column(index);\n",
    "        if (c.type() == ColumnType.FLOAT) {\n",
    "            float[] newArray = new float[c.size()];\n",
    "            System.arraycopy(c.asList().toArray(), 0, newArray, 0, c.size());\n",
    "            return newArray;\n",
    "        }\n",
    "        return null;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来我们实现一个训练函数，它会在`trainingSet` 访问到的训练数据集上训练一个模型`net`。该训练函数将会运行多个迭代周期（由`numEpochs`指定）。在每个迭代周期结束时，利用 `validationSet` 访问到的测试数据集对模型进行评估。我们将利用 `Animator` 类来可视化训练进度。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 44,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public void trainCh3(UnaryOperator<NDArray> net, Dataset trainDataset, Dataset testDataset, \n",
    "                     BinaryOperator<NDArray> loss, int numEpochs, ParamConsumer updater) \n",
    "                                                            throws IOException, TranslateException {\n",
    "    Animator animator = new Animator();\n",
    "    for (int i = 1; i <= numEpochs; i++) {\n",
    "        float[] trainMetrics = trainEpochCh3(net, trainDataset.getData(manager), loss, updater);\n",
    "        float accuracy = evaluateAccuracy(net, testDataset.getData(manager));\n",
    "        float trainAccuracy = trainMetrics[1];\n",
    "        float trainLoss = trainMetrics[0];\n",
    "        \n",
    "        animator.add(i, accuracy, trainAccuracy, trainLoss);\n",
    "        System.out.printf(\"Epoch %d: Test Accuracy: %f\\n\", i, accuracy);\n",
    "        System.out.printf(\"Train Accuracy: %f\\n\", trainAccuracy);\n",
    "        System.out.printf(\"Train Loss: %f\\n\", trainLoss);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "作为一个从零开始的实现，我们使用 :numref:`sec_linear_scratch` 中定义的小批量随机梯度下降来优化模型的损失函数，设置学习率为 0.1。现在，我们训练模型5个迭代周期。请注意，迭代周期（`numEpochs`）和学习率（`lr`）都是可调节的超参数。通过更改它们的值，我们可以提高模型的分类准确率。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 46,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "int numEpochs = 5;\n",
    "float lr = 0.1f;\n",
    "\n",
    "public class Updater {\n",
    "    public static void updater(NDList params, float lr, int batchSize) {\n",
    "        Training.sgd(params, lr, batchSize);\n",
    "    }\n",
    "}\n",
    "\n",
    "trainCh3(Net::net, trainingSet, validationSet, LossFunction::crossEntropy, numEpochs, Updater::updater);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测\n",
    "\n",
    "现在训练已经完成，我们的模型已经准备好对图像进行分类预测。给定一系列图像，我们将比较它们的实际标签（文本输出的第一行）和模型预测（文本输出的第二行）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 51,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Saved in the FashionMnistUtils class for later use\n",
    "// Number should be < batchSize for this function to work properly\n",
    "public void predictCh3(UnaryOperator<NDArray> net, ArrayDataset dataset, int number, NDManager manager) \n",
    "    throws IOException, TranslateException {\n",
    "    final int SCALE = 4;\n",
    "    final int WIDTH = 28;\n",
    "    final int HEIGHT = 28;\n",
    "    \n",
    "    int[] predLabels = new int[number];\n",
    "    \n",
    "    for (Batch batch : dataset.getData(manager)) {\n",
    "        NDArray X = batch.getData().head();\n",
    "        int[] yHat = net.apply(X).argMax(1).toType(DataType.INT32, false).toIntArray();\n",
    "        for (int i = 0; i < number; i++) {\n",
    "            predLabels[i] = yHat[i];\n",
    "        }\n",
    "        break;\n",
    "    }\n",
    "    \n",
    "    FashionMnistUtils.showImages(dataset, predLabels, WIDTH, HEIGHT, SCALE, manager);\n",
    "}\n",
    "\n",
    "/* Note: Uncomment the following line and run to do predictions\n",
    "   The image classification will show up in another window */\n",
    "// predictCh3(Net::net, validationSet, 6, manager);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Fashion Mnist Predictions.](https://d2l-java-resources.s3.amazonaws.com/img/fashion_mnist_pred.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 借助 softmax 回归，我们可以训练多分类的模型。\n",
    "* softmax 回归的训练循环与线性回归中的训练循环非常相似：读取数据、定义模型和损失函数，然后使用优化算法训练模型。正如你很快就会发现的那样，大多数常见的深度学习模型都有类似的训练过程。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 在本节中，我们直接实现了基于数学定义softmax运算的`softmax`函数。这可能会导致什么问题？提示：尝试计算 $\\exp(50)$ 的大小。\n",
    "1. 本节中的函数 `crossEntropy` 是根据交叉熵损失函数的定义实现的。这个实现可能有什么问题？提示：考虑对数的值域。\n",
    "1. 你可以想到什么解决方案来解决上述两个问题？\n",
    "1. 返回概率最大的标签总是一个好主意吗？例如，医疗诊断场景下你会这样做吗？\n",
    "1. 假设我们希望使用softmax回归来基于某些特征预测下一个单词。词汇量大可能会带来哪些问题?\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
